package com.beau.filter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.support.ipresolver.RemoteAddressResolver;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Global gateway filter that rate-limits requests per client IP.
 *
 * <p>Each client IP may issue at most {@code rate_limit.request.times} requests within a
 * sliding window of {@code rate_limit.request.period} minutes. A request over the limit is
 * answered directly with HTTP 429 and is not passed down the filter chain. Rejected requests
 * are still recorded, so a client that keeps sending stays over the limit until it slows down.
 *
 * <p>Timestamps older than the window are evicted on every request from the same IP, so the
 * per-IP history is bounded by the traffic inside one window. Map entries for IPs that never
 * come back are not removed by this class.
 */
@Component
@RefreshScope
public class RateLimitFilter implements GlobalFilter, Ordered {

    private static final Logger logger = LoggerFactory.getLogger(RateLimitFilter.class);

    /** Milliseconds per minute; a long so the window arithmetic cannot overflow int. */
    private static final long MILLIS_PER_MINUTE = 60_000L;

    /** Body returned to clients that exceeded the limit. */
    private static final String LIMIT_EXCEEDED_MESSAGE = "request up to the limit, please try later";

    // One IP may send reqTimes requests within period (unit: minutes).
    @Value("${rate_limit.request.period}")
    private int period;

    @Value("${rate_limit.request.times}")
    private int reqTimes;

    // Request history: key is the client IP, value holds the request timestamps (epoch millis)
    // in arrival order. Each deque is guarded by its own monitor.
    private final ConcurrentHashMap<String, Deque<Long>> requestRecords = new ConcurrentHashMap<>();

    @Autowired
    private RemoteAddressResolver addressResolver;

    /**
     * Records the request for the caller's IP and either forwards it or rejects it.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return the downstream chain's completion signal, or the completion of the rejection
     *         response when the caller exceeded the limit
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        logger.info("在{}分钟内，单个ip可以执行{}次请求", period, reqTimes);
        // NOTE(review): the resolved address is dereferenced without a null check — confirm the
        // configured resolver never returns null.
        String requestIP = addressResolver.resolve(exchange).getHostString();

        int count = recordAndCount(requestIP, System.currentTimeMillis());
        logger.info("单位时间内请求次数为 {}", count);

        if (count > reqTimes) {
            return reject(exchange);
        }
        return chain.filter(exchange);
    }

    /**
     * Appends {@code now} to the history of {@code requestIP}, drops timestamps that fell out of
     * the window, and returns how many requests remain inside the window (including this one).
     */
    private int recordAndCount(String requestIP, long now) {
        // Long arithmetic: period * 60 * 1000 in int overflows for large periods.
        long windowStart = now - period * MILLIS_PER_MINUTE;
        Deque<Long> timestamps = requestRecords.computeIfAbsent(requestIP, ip -> new ArrayDeque<>());

        // Lock only this IP's history so unrelated clients do not contend with each other.
        synchronized (timestamps) {
            timestamps.addLast(now);
            // Evict everything strictly older than the window start; this keeps the deque
            // bounded and leaves exactly the timestamps >= windowStart to be counted.
            Long oldest = timestamps.peekFirst();
            while (oldest != null && oldest < windowStart) {
                timestamps.pollFirst();
                oldest = timestamps.peekFirst();
            }
            return timestamps.size();
        }
    }

    /** Writes the "limit exceeded" response with status 429 and ends the exchange. */
    private Mono<Void> reject(ServerWebExchange exchange) {
        logger.info("请求超过次数限制");
        // 429 is the status defined for rate limiting; 303 would require a Location header.
        exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
        // Explicit charset so the bytes do not depend on the platform default.
        byte[] body = LIMIT_EXCEEDED_MESSAGE.getBytes(StandardCharsets.UTF_8);
        DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(body);
        return exchange.getResponse().writeWith(Mono.just(buffer));
    }

    /**
     * Returns the order value of this filter.
     *
     * @return {@code -1}
     */
    @Override
    public int getOrder() {
        return -1;
    }
}
